from netCDF4 import Dataset
from wrf import getvar
import math
import os
import glob
import pandas as pd
from datetime import datetime

class WRFProcessor:
    """Extract values at fixed CAMS station grid points from WRF output files.

    Every file in ``input_dir`` matching ``file_pattern`` contributes one value
    per variable per station (read at ``time_idx``); :meth:`export_to_csv`
    then writes one CSV per station into ``output_dir``.
    """

    def __init__(self, input_dir='../', output_dir='../', *,
                 time_idx=0, file_pattern='wrfout_d03*'):
        """Initialize the processor.

        Args:
            input_dir: directory searched for WRF output files.
            output_dir: directory the per-station CSV files are written to
                (created on export if missing).
            time_idx: time index passed to ``getvar`` for every file.
            file_pattern: glob pattern, relative to ``input_dir``, selecting
                the WRF files to read.
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.time_idx = time_idx  # Time index for data extraction
        self.file_pattern = file_pattern

        # CAMS station positions as (row_indices, col_indices) into the 2-D
        # model grid; presumably (south_north, west_east) — TODO confirm.
        # The single-element lists make the indexed result a one-element
        # array, which _value_at unwraps.
        # NOTE(review): CAMS169 and CAMS670 map to the same grid cell.
        self.stations_positions = {
            'CAMS404': ([67], [67]),
            'CAMS1052': ([68], [60]),
            'CAMS695': ([59], [63]),
            'CAMS53': ([57], [51]),
            'CAMS409': ([50], [53]),
            'CAMS8': ([76], [65]),
            'CAMS416': ([56], [67]),
            'CAMS1': ([64], [73]),
            'CAMS603': ([63], [76]),
            'CAMS403': ([61], [70]),
            'CAMS167': ([61], [72]),
            'CAMS1029': ([59], [70]),
            'CAMS169': ([58], [70]),
            'CAMS670': ([58], [70]),
            'CAMS1020': ([56], [70]),
            'CAMS1049': ([58], [73]),
            'CAMS698': ([107], [52])
        }

        # Variables to track (also the CSV column order, before 'timestamp')
        self.variables = ['WSPD', 'temperature', 'pm25', 'ozone']

        # {station: {variable: [one value per processed file]}}
        self.stations_dict = self._initialize_stations_dict()

    def _initialize_stations_dict(self):
        """Return ``{station: {variable: []}}`` with an independent list per cell."""
        return {
            station: {var: [] for var in self.variables}
            for station in self.stations_positions
        }

    def get_wrf_files(self):
        """Return the sorted paths of files in ``input_dir`` matching ``file_pattern``.

        The process working directory is not changed: calling
        ``os.chdir(input_dir)`` and then globbing ``input_dir/pattern`` would
        resolve a relative ``input_dir`` twice and would also move where a
        relative ``output_dir`` points.
        """
        return sorted(glob.glob(os.path.join(self.input_dir, self.file_pattern)))

    @staticmethod
    def _value_at(field, position):
        """Return ``field[position]`` as a Python float.

        ``position`` is a (row_indices, col_indices) pair of single-element
        lists, so indexing yields a one-element array; ``.item()`` unwraps it
        and raises ``ValueError`` if more than one element was selected.
        """
        return float(field[position].item())

    def process_file(self, ncfile):
        """Append one value per tracked variable for every station from ``ncfile``.

        Args:
            ncfile: path to a WRF(-Chem) output file containing U10, V10, T2,
                PM2_5_DRY and o3.
        """
        # The context manager closes the netCDF handle even if extraction
        # fails, instead of leaking one open file per processed output.
        with Dataset(ncfile) as data:
            u10 = getvar(data, "U10", self.time_idx)
            v10 = getvar(data, "V10", self.time_idx)
            surface_temp = getvar(data, "T2", self.time_idx)

            # [0] takes the first vertical level of these 3-D fields
            # (presumably the lowest model level — confirm).
            pm_2_5 = getvar(data, "PM2_5_DRY", self.time_idx)[0]
            ozone = getvar(data, "o3", self.time_idx)[0]

            for station_name, station_pos in self.stations_positions.items():
                station_data = {
                    # Wind speed from the 10 m wind components.
                    'WSPD': math.hypot(self._value_at(u10, station_pos),
                                       self._value_at(v10, station_pos)),
                    'temperature': self._value_at(surface_temp, station_pos),
                    'pm25': self._value_at(pm_2_5, station_pos),
                    'ozone': self._value_at(ozone, station_pos),
                }
                for var, value in station_data.items():
                    self.stations_dict[station_name][var].append(value)

    def process_all_files(self):
        """Process every file returned by :meth:`get_wrf_files`, in sorted order."""
        for ncfile in self.get_wrf_files():
            self.process_file(ncfile)

    def export_to_csv(self, start_time=None, freq=None):
        """Write ``<output_dir>/<station>_data.csv`` for every station.

        Timestamps are synthesized, not read from the WRF files: row ``i`` is
        stamped ``start_time + i * freq``.

        Args:
            start_time: timestamp of the first row (anything accepted by
                ``pandas.date_range``). Defaults to midnight of the current
                local day.
            freq: spacing between consecutive rows. Defaults to one hour.
        """
        os.makedirs(self.output_dir or '.', exist_ok=True)

        # Computed once so every station shares the same time axis; seconds
        # and microseconds are cleared so the default really is midnight.
        if start_time is None:
            start_time = datetime.now().replace(
                hour=0, minute=0, second=0, microsecond=0)
        if freq is None:
            freq = pd.Timedelta(hours=1)

        for station, series in self.stations_dict.items():
            df = pd.DataFrame(series)
            df['timestamp'] = pd.date_range(
                start=start_time,
                periods=len(df),
                freq=freq
            )

            output_file = os.path.join(self.output_dir, f'{station}_data.csv')
            df.to_csv(output_file, index=False)
            print(f"Saved data for {station} to {output_file}")

def main():
    """Run the default pipeline: read every WRF file, then write station CSVs."""
    pipeline = WRFProcessor()
    pipeline.process_all_files()
    pipeline.export_to_csv()


if __name__ == "__main__":
    main()